import numpy as np
import matplotlib.pyplot as plt
from scipy.special import wofz

# --- Plotting Setup ---
# Render all text through LaTeX so the \textbf / \mathbf markup used in the
# figure titles and labels is typeset rather than shown literally.
plt.rc("text", usetex=True)
# Serif body font, Computer Modern Roman, at a 14 pt base size.
plt.rc("font", family="serif", serif=["Computer Modern Roman"], size=14)
# Axes titles slightly larger than the axis labels.
plt.rc("axes", titlesize=16, labelsize=14)

# --- Core Physics Functions ---

def f_norm(Omega):
    """
    Unruh-mode normalization:
      f(Ω) = exp(-πΩ/2) / sqrt(8π |Ω| sinh(π|Ω|))

    The expression is evaluated in the algebraically equivalent form
      f(Ω) = exp(-π(Ω + |Ω|)/2) / sqrt(4π |Ω| (1 - exp(-2π|Ω|)))
    which follows from sinh(x) = e^x (1 - e^{-2x}) / 2.  The direct form
    overflows in sinh once |Ω| exceeds roughly 226 and then returns 0 (or
    nan) where the true value is about 1/sqrt(4π|Ω|); this form only ever
    exponentiates non-positive numbers, so it stays finite for all finite Ω.

    Parameters
    ----------
    Omega : float or array_like
        Mode frequency or frequencies; converted to float64.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        f(Ω) evaluated elementwise, with the same shape as the input.
        At Ω = 0 the singular |Ω| is replaced by 1e-16, so the result is
        large but finite instead of a division by zero.
    """
    Om = np.asarray(Omega, dtype=np.float64)
    absOm = np.abs(Om)

    # The exponent -π(Ω + |Ω|)/2 is 0 for Ω <= 0 and -πΩ for Ω > 0, so it
    # can underflow to 0 but never overflow.  It uses the unprotected |Ω|
    # so that Ω = 0 gives exactly exp(0) = 1, as in the defining formula.
    numerator = np.exp(-0.5 * np.pi * (Om + absOm))

    absOm = np.where(absOm == 0.0, 1e-16, absOm)  # protect at 0

    # -expm1(-2π|Ω|) = 1 - exp(-2π|Ω|), accurate even for tiny |Ω|.
    denom = np.sqrt(4.0 * np.pi * absOm * -np.expm1(-2.0 * np.pi * absOm))
    return numerator / denom

def M_unified(Omega, Omega_p, chi, chi_p, a, T, omega):
    """
    Unified finite-time amplitude (ground-state initial detector).

    Evaluates, elementwise over the broadcast of Omega and Omega_p,
      M_{χχ'} = (g² a² / 2ħ²) χχ' ΩΩ' a^{i(Ω+Ω')} f(-χΩ) f(-χ'Ω')
                · exp[-(aT)²/4 (χΩ+χ'Ω')²]
                · w( T[(a/2)(χΩ-χ'Ω') - ω] )
    with g = ħ = 1 fixed inside the function, f = f_norm, and w the
    Faddeeva function (scipy.special.wofz).

    Parameters
    ----------
    Omega, Omega_p : float or array_like
        The two mode frequencies Ω and Ω'; converted to float64.
    chi, chi_p : number
        Multipliers χ and χ' applied to Ω and Ω' (the script in this file
        passes +1 or -1).
    a : float
        Scale entering as a², a/2 and log(a); log(a) requires a > 0 for a
        real-valued phase.
    T : float
        Time scale multiplying the Gaussian width and the Faddeeva argument.
    omega : float
        Offset ω subtracted inside the Faddeeva argument.

    Returns
    -------
    complex scalar or numpy.ndarray of complex128
        The amplitude on the broadcast shape of Omega and Omega_p.
    """
    coupling = 1.0  # g
    hbar = 1.0

    Om = np.asarray(Omega, dtype=np.float64)
    Om_p = np.asarray(Omega_p, dtype=np.float64)

    # Signed frequencies χΩ and χ'Ω' recur in every factor below.
    signed = chi * Om
    signed_p = chi_p * Om_p

    # Real amplitude prefactor (g² a² / 2ħ²) χ χ' Ω Ω'.
    scale = (coupling**2 * a**2) / (2.0 * hbar**2) * chi * chi_p * Om * Om_p

    # a^{i(Ω+Ω')} written as exp(i (Ω+Ω') ln a).
    phase = np.exp(1j * (Om + Om_p) * np.log(a))

    # Product of the two mode normalizations f(-χΩ) f(-χ'Ω').
    norms = f_norm(-signed) * f_norm(-signed_p)

    # Gaussian envelope in the sum χΩ + χ'Ω'.
    envelope = np.exp(-0.25 * (a**2) * (T**2) * (signed + signed_p)**2)

    # wofz wants a complex argument; the value itself is real here.
    w_arg = T * (0.5 * a * (signed - signed_p) - omega)
    w_val = wofz(w_arg.astype(np.complex128))

    # Multiply in the same left-to-right order as the formula above.
    amplitude = scale * phase
    amplitude = amplitude * norms
    amplitude = amplitude * envelope
    return amplitude * w_val

# --- Main Plotting Script ---

# Parameters passed to M_unified, and the number of samples per axis.
omega, a, T = 1.0, 1.0, 5.0
n_points = 250

# Square (Ω, Ω') sampling grid; with indexing='xy', Ω runs along columns
# (x axis) and Ω' along rows (y axis).
Omega_vals = np.linspace(-3.5, 3.5, n_points)
Op_vals = np.linspace(-3.5, 3.5, n_points)
O_grid, Op_grid = np.meshgrid(Omega_vals, Op_vals, indexing='xy')

# Complex amplitudes for the four (χ, χ') sign combinations, in the order
# RR, LL, RL, LR.
Z_RR, Z_LL, Z_RL, Z_LR = [
    M_unified(O_grid, Op_grid, chi, chi_p, a, T, omega)
    for chi, chi_p in ((1, 1), (-1, -1), (1, -1), (-1, 1))
]

# Squared magnitudes |M|² of each amplitude.
Prob_RR, Prob_LL, Prob_RL, Prob_LR = [
    np.abs(amplitude) ** 2 for amplitude in (Z_RR, Z_LL, Z_RL, Z_LR)
]

# One (data, panel title) pair per subplot, in reading order.
plot_data = list(zip(
    (Prob_RR, Prob_LL, Prob_RL, Prob_LR),
    (
        r'$\mathbf{(a)\ RR\ Channel}$',
        r'$\mathbf{(b)\ LL\ Channel}$',
        r'$\mathbf{(c)\ RL\ Channel}$',
        r'$\mathbf{(d)\ LR\ Channel}$',
    ),
))

fig, axs = plt.subplots(2, 2, figsize=(11, 10), constrained_layout=True)
fig.suptitle(r'\textbf{Two-Photon Emission Spectra with Gaussian Switching}', fontsize=18)

# Shared styling for the faint dashed guide lines through the origin.
guide_style = dict(color='white', linestyle='--', linewidth=0.5, alpha=0.5)

for ax, (data, title) in zip(axs.flat, plot_data):
    # Scale each panel by its own peak so every panel spans [0, 1].
    norm_data = data / data.max()

    # Colour map of the scaled data, with contours at 50% and 80% of peak.
    im = ax.pcolormesh(O_grid, Op_grid, norm_data, cmap='jet', shading='auto', vmin=0.0, vmax=1.0)
    ax.contour(O_grid, Op_grid, norm_data, levels=[0.5, 0.8], colors='white', linewidths=0.7, alpha=0.6)

    # Guide lines marking Ω' = 0 and Ω = 0.
    ax.axhline(0, **guide_style)
    ax.axvline(0, **guide_style)

    ax.set(
        title=title,
        xlabel=r'$\mathbf{\Omega}$',
        ylabel=r"$\mathbf{\Omega'}$",
        aspect='equal',
    )

# All panels share vmin/vmax, so the last mappable can serve one common colorbar.
cbar = fig.colorbar(im, ax=list(axs.flat), shrink=0.8, pad=0.03)
cbar.set_label(r'\textbf{Normalized Probability Density} $|\mathcal{M}|^2$')

output_filename = 'figure_channels.png'
plt.savefig(output_filename, dpi=300, bbox_inches='tight')
print(f"✅ Figure 1 has been successfully generated and saved as '{output_filename}'")
